from torchattacks import CW
import numpy as np
import torch
import os
from ASVBaselineTool.feature_tools import *
import soundfile as sf
import torchaudio
import torch.nn as nn
import torch.optim as optim



# CW (Carlini-Wagner) attack framework — note: this class shadows `torchattacks.CW` imported above
class CW():
    """Carlini-Wagner (L2-style) adversarial attack for an audio anti-spoofing model.

    Labels are 2-class one-hot vectors with the ordering [fake, genuine]
    (index 0 = fake, index 1 = genuine), as built in __init__.
    """

    def __init__(self, model, labelPath, datasetPath, AdvsetPath, kappa=0, c=1e-4, steps=1000):
        """
        Args:
            model: classifier returning 2-class logits; moved to CUDA and set to eval here.
            labelPath: protocol file, one "<fname> ... <genuine|fake>" entry per line.
            datasetPath: directory containing the original audio files.
            AdvsetPath: directory where successful adversarial samples are saved.
            kappa: CW confidence margin used in f().
            c: weight of the classification objective relative to the L2 term.
            steps: optimization iterations per sample.
        """
        # Protocol format: first whitespace token = file name, last token = label.
        with open(labelPath, "r") as fp:
            lines = fp.readlines()
        self.AFnames = [item.split(" ")[0] for item in lines]
        ALabels = [item.split(" ")[-1].strip() for item in lines]
        # file name -> one-hot label tensor ([fake, genuine]).
        self.LabelMeta = {}
        for fname, lab in zip(self.AFnames, ALabels):
            if lab == "genuine":
                lab = torch.Tensor([0.0, 1.0])
            elif lab == "fake":
                lab = torch.Tensor([1.0, 0.0])
            self.LabelMeta[fname] = lab
        self.model = model
        self.model.cuda()
        self.model.eval()
        self.datasetPath = datasetPath
        # Output directory for adversarial samples (same file names as AFnames).
        self.AdvSavePath = AdvsetPath
        self.c = c
        self.kappa = kappa
        self.steps = steps

    def feature_select(self, feature, wav: torch.Tensor):
        """Turn a waveform into the front-end feature named by `feature`.

        Raises:
            ValueError: for an unknown feature name (the original silently
                returned None, which crashed later with a confusing error).
        """
        if feature == "spec_2048":
            return get_SPEC_2048(wav)
        elif feature == "mfcc_40":
            return get_MFCC_40(wav)
        elif feature == "lfcc_70":
            return get_LFCC_70(wav)
        elif feature == "raw":
            # raw waveform: add a batch dimension -> (1, num_samples)
            return wav.reshape(1, -1)
        raise ValueError("unknown feature: " + feature)

    def model_bn_to_eval(self):
        """Put the model in train mode but freeze every BatchNorm layer.

        Generalized from a hand-written list of specific BN attributes
        (first_bn, blockN[0].bn1/bn2, bn_before_gru) to a walk over all
        submodules, so an architecture change cannot silently skip a layer.
        NOTE(review): assumes every BatchNorm in the model should be frozen —
        the original list appeared exhaustive; confirm against the model class.
        """
        self.model.train()
        for m in self.model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                m.eval()

    def Attack(self, feature: str, lr=0.0001):
        """Run the CW attack over every file in the protocol; save successes.

        Args:
            feature: one of "spec_2048" / "mfcc_40" / "lfcc_70" / "raw".
            lr: Adam learning rate for the tanh-space optimization.
        """
        print("attack alpha:"+str(lr))
        num_f = len(self.AFnames)
        print("总样本：" + str(num_f))
        num_as = 0  # number of successful attacks
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for item in self.AFnames:
            src_wav, sr = sf.read(os.path.join(self.datasetPath, item))
            src_wav = torch.Tensor(src_wav)
            if feature == "raw":
                # gradients reach the waveform directly; freeze BN statistics
                self.model_bn_to_eval()
            adv_feature = self.feature_select(feature, src_wav).cuda()
            init_pred = torch.max(self.model(adv_feature), dim=1)[1].item()
            y = self.LabelMeta[item]
            init_y = y.max(0, keepdim=True)[1].item()
            # Skip samples the model already misclassifies.
            if init_pred != init_y:
                print("原始分类错误 pass")
                continue
            # --- CW optimization in tanh space ---
            # BUG FIX: clamp to the open interval so arctanh never returns
            # +-inf when a sample sits exactly at +-1.
            w = torch.arctanh(torch.clamp(src_wav, -1 + 1e-6, 1 - 1e-6)).detach()
            w.requires_grad = True
            prev_cost = 1e10
            MSELoss = nn.MSELoss(reduction='none')
            Flatten = nn.Flatten()
            optimizer = optim.Adam([w], lr=lr)
            count = 0
            for i in range(self.steps):
                perturbed_wav = torch.tanh(w)
                # per-sample squared L2 distortion w.r.t. the original waveform
                current_L2 = MSELoss(Flatten(perturbed_wav.unsqueeze(0)),
                                     Flatten(src_wav.unsqueeze(0))).sum(dim=1)
                L2_loss = current_L2.sum()
                adv_feature = self.feature_select(feature, perturbed_wav).cuda()
                output = self.model(adv_feature)
                f_loss = self.f(output)
                cost = L2_loss.cuda() + self.c * f_loss
                optimizer.zero_grad()
                cost.backward()
                optimizer.step()
                count += 1
                # Early-stop check every steps//10 iterations.
                # BUG FIX: the original tested `self.steps % (self.steps // 10)`,
                # a constant (0 for steps=1000), so the check ran every
                # iteration; it also divided by zero for steps < 10.
                if i % max(self.steps // 10, 1) == 0:
                    if cost.item() > prev_cost:
                        break
                    prev_cost = cost.item()
                    print(prev_cost)
            print(count)
            # Re-evaluate the finished adversarial waveform in eval mode.
            # BUG FIX: the original re-tested and saved `best_adv_wav`, which
            # was never updated inside the loop (the tracking code was
            # commented out), so it always operated on the ORIGINAL waveform.
            perturbed_wav = torch.tanh(w).detach().cpu()
            temp_adv_feature = self.feature_select(feature, wav=perturbed_wav).to(device)
            self.model.eval()
            temp_init_pred = self.model(temp_adv_feature).max(1)[1].item()
            if temp_init_pred == init_y:
                print("攻击失败 pass")
                continue
            sp = os.path.join(self.AdvSavePath, item)
            print("攻击成功并保存值路径:"+sp)
            num_as += 1
            # BUG FIX: flatten to 1-D before writing — the original called
            # squeeze(0) on a 1-D numpy array, which raises ValueError.
            sf.write(sp, perturbed_wav.numpy().reshape(-1), sr)
        print("总样本：" + str(num_f))
        print("攻击成功样本数量:"+str(num_as))
        # guard against an empty protocol file (ZeroDivisionError)
        print("攻击成功率:"+str(num_as / num_f if num_f else 0.0))

    def loss(self, x1, x2):
        """Euclidean (L2) distance between two tensors."""
        return torch.norm(x1 - x2, p=2)

    def f(self, outputs):
        """CW objective term: clamp(fake_logit - genuine_logit, min=-kappa).

        Minimizing this pushes the prediction toward class index 1 (genuine).
        NOTE(review): the target class is hard-coded via one_hot_labels; for a
        sample whose true label is already 'genuine' this objective reinforces
        the correct class instead of attacking it — confirm that only 'fake'
        samples are meant to be attacked, or thread the true label through.
        """
        one_hot_labels = torch.Tensor([0., 1.]).cuda()
        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)  # logit of class 0 (fake)
        j = torch.masked_select(outputs, one_hot_labels.bool())  # logit of class 1 (genuine)
        return torch.clamp((i - j), min=-self.kappa)













